#!/usr/bin/env python3
from typing import Optional
from pathlib import Path
import sys; sys.path.append(str(Path(__file__).parent.parent.resolve()))

from plot_master import main, colors
import wandb

from rpi.scripts.pretraining.experts import (
    cheetah_ppo,
    cheetah_sac,
    walker_ppo,
    walker_sac,
    pendulum_ppo,
    pendulum_sac,
    cartpole_ppo,
    cartpole_sac,
)

def convert(model_infos):
    """Reduce expert metadata dicts to ``(policy, path)`` tuples.

    Args:
        model_infos: Iterable of mappings, each providing at least the keys
            ``'policy'`` and ``'path'`` (other keys are ignored).

    Returns:
        A new list of ``(info['policy'], info['path'])`` tuples, in input order.

    Raises:
        KeyError: if an entry lacks ``'policy'`` or ``'path'``.
    """
    pairs = []
    for info in model_infos:
        pairs.append((info['policy'], info['path']))
    return pairs

# Flatten every imported expert list into (policy, path) pairs, rebinding
# each imported name to its converted list.
(
    cheetah_ppo, cheetah_sac,
    walker_ppo, walker_sac,
    pendulum_ppo, pendulum_sac,
    cartpole_ppo, cartpole_sac,
) = [
    convert(expert_list)
    for expert_list in (
        cheetah_ppo, cheetah_sac,
        walker_ppo, walker_sac,
        pendulum_ppo, pendulum_sac,
        cartpole_ppo, cartpole_sac,
    )
]


def _expert_subsets(ppo, sac, strided_sac=True):
    """Return the expert subsets plotted for one domain, in plot order.

    The subsets are: the first three PPO experts, the first three of every
    fourth PPO expert, the first three SAC experts, optionally the first
    three of every fourth SAC expert, and the last three SAC experts.
    """
    subsets = [ppo[:3], ppo[::4][:3], sac[:3]]
    if strided_sac:
        subsets.append(sac[::4][:3])
    subsets.append(sac[-3:])
    return subsets


domain2expert_info = {
    'cheetah-run': _expert_subsets(cheetah_ppo, cheetah_sac),
    'walker-walk': _expert_subsets(walker_ppo, walker_sac),
    # pendulum-swingup omits the every-fourth-SAC subset.
    'pendulum-swingup': _expert_subsets(pendulum_ppo, pendulum_sac, strided_sac=False),
    'cartpole-swingup': _expert_subsets(cartpole_ppo, cartpole_sac),
}
# domain2expertsteps = {
#     # 'cheetah-run': [[100], [100, 70], [100, 70, 40], [100, 70, 40, 20]],
#     'cheetah-run': [[100, 70, 40]],
#     'walker-walk': [[190, 150, 100, 80], [150, 100, 80, 50], [130, 100, 80, 40]],
#     'pendulum-swingup': [[200, 150], [200, 150, 100], [200, 150, 100, 50]],
#     'cartpole-swingup': [[400, 300, 200, 40], [400, 140, 80], [400, 160, 60]]
# }

# ASE sigma value per domain.
domain2ase_sigma = {
    'cheetah-run': 2.5,
    'cartpole-swingup': 0.25,
    'walker-walk': 10,
    'pendulum-swingup': 0.25,
}

# Sigma sweep for cheetah-run: 0.5, 1.0, ..., 10.0 (20 values).
ase_sigmas_on_cheetah = [0.5 * step for step in range(1, 21)]

def maybe_toint(val):
    """Return ``val`` as an ``int`` when it has no fractional part.

    Integral floats such as ``2.0`` become ``2``, so they format as ``'2'``
    rather than ``'2.0'``. Non-integral floats are returned unchanged.

    Ints (including bools) are normalized with ``int()`` without calling
    ``is_integer``, because ``int.is_integer`` only exists on Python 3.12+.
    """
    if isinstance(val, int):
        return int(val)
    if val.is_integer():
        return int(val)
    return val


def get_aps_vs_ase_query_set(
        domain,
        expert_steps,
        algorithms=["lops-aps", "lops-aps-ase", "mamba", "pg-gae"],
        learner_pi=["all"],
        group="original",
        ase_sigmas: Optional[list] = None,
        alg2group: Optional[dict] = None
):
    """Build a run filter comparing LOPS-APS against LOPS-APS-ASE.

    Each requested algorithm contributes one alternative to an ``$or``. Each
    alternative carries its own ``group`` constraint. The environment
    constraint applies to all alternatives.

    Args:
        domain: Task name such as ``'cheetah-run'``. It is mapped to the env id
            ``dmc:<Capitalized>-v1`` (e.g. ``dmc:Cheetah-run-v1``).
        expert_steps: Value matched exactly against
            ``config.load_expert_step``. It is not used for ``pg-gae``, which
            is always matched against ``[0]``.
        algorithms: Names to include. Supported values are ``'lops-aps'``,
            ``'lops-aps-ase'``, ``'mamba'`` and ``'pg-gae'``.
        learner_pi: Allowed ``config.use_riro_for_learner_pi`` values for
            ``lops-aps``. A bare string is treated as a one-element list.
        group: Fallback group for algorithms without an ``alg2group`` entry.
            ``'original'`` matches any group except ``'sigmas'``. Any other
            value must match exactly.
        ase_sigmas: Optional list of allowed ``config.ase_sigma`` values for
            ``lops-aps-ase``.
        alg2group: Optional mapping from algorithm name to the exact group it
            must match. A value of ``None`` is kept as ``{'$eq': None}``.

    Returns:
        A MongoDB-style filter dict.

    Raises:
        KeyError: if ``algorithms`` contains an unsupported name.
    """
    def _as_list(values):
        # Copy so the returned filter never aliases a caller's (or the
        # default) list. A bare string or scalar is wrapped, not iterated.
        if isinstance(values, str):
            return [values]
        try:
            return list(values)
        except TypeError:
            return [values]

    env_name = f"dmc:{domain.capitalize()}-v1"

    # HACK: the first set of runs was logged without a group name, so
    # "original" means "any group except 'sigmas'".
    default_group = {'$ne': 'sigmas'} if group == 'original' else {'$eq': group}

    sigma_filter = {} if ase_sigmas is None else {'config.ase_sigma': {'$in': _as_list(ase_sigmas)}}

    alg2query = {
        'lops-aps': {
            'config.algorithm': {'$eq': 'lops-aps'},
            'config.use_riro_for_learner_pi': {'$in': _as_list(learner_pi)},
            'config.load_expert_step': {'$eq': expert_steps},
        },
        'lops-aps-ase': {
            'config.algorithm': {'$eq': 'lops-aps-ase'},
            # ASE runs are always restricted to 'all', regardless of learner_pi.
            'config.use_riro_for_learner_pi': {'$eq': 'all'},
            'config.load_expert_step': {'$eq': expert_steps},
            **sigma_filter,
        },
        # The baselines below mirror get_query_set so that the default
        # `algorithms` list is fully supported.
        'mamba': {
            'config.algorithm': {'$eq': 'mamba'},
            'config.use_riro_for_learner_pi': {'$eq': 'none'},
            'config.load_expert_step': {'$eq': expert_steps},
        },
        'pg-gae': {
            'config.algorithm': {'$eq': 'pg-gae'},
            'config.use_riro_for_learner_pi': {'$eq': 'none'},
            'config.load_expert_step': {'$eq': [0]},
        },
    }

    def _group_filter(alg):
        # An explicit alg2group entry wins. Its value may be None, which
        # presumably matches runs logged without a group. Otherwise fall back
        # to the `group` argument.
        if alg2group is not None and alg in alg2group:
            return {'$eq': alg2group[alg]}
        return dict(default_group)

    return {
        "$and": [{
            "$or": [{**alg2query[alg], 'group': _group_filter(alg)} for alg in algorithms],
            "$and": [{
                'config.env_name': {'$eq': env_name},
            }]
        }]
    }


def get_query_set(
    domain,
    expert_steps,
    algorithms=["lops-aps", "lops-aps-ase", "mamba", "pg-gae"],
    learner_pi=["all"],
    ase_learner_pi=["all"],
    group="original",
    ase_sigmas: Optional[list] = None,
    aps_ase_extra: Optional[dict] = None,
):
    """Build the run filter for one domain and expert set.

    Each requested algorithm contributes one alternative to an ``$or``. The
    environment and group constraints apply to all of them.

    Args:
        domain: Task name such as ``'cheetah-run'``. It is mapped to the env id
            ``dmc:<Capitalized>-v1`` (e.g. ``dmc:Cheetah-run-v1``).
        expert_steps: Value matched exactly against
            ``config.load_expert_step``. It is not used for ``pg-gae``, which
            is always matched against ``[0]``.
        algorithms: Names to include. Supported values are ``'lops-aps'``,
            ``'lops-aps-ase'``, ``'mamba'`` and ``'pg-gae'``.
        learner_pi: Allowed ``config.use_riro_for_learner_pi`` values for
            ``lops-aps``. A bare string is treated as a one-element list.
        ase_learner_pi: Same as ``learner_pi``, but for ``lops-aps-ase``.
        group: Group to match. ``'original'`` matches any group except
            ``'sigmas'``. Any other value must match exactly.
        ase_sigmas: Optional list of allowed ``config.ase_sigma`` values for
            ``lops-aps-ase``.
        aps_ase_extra: Optional extra filter clauses merged last into the
            ``lops-aps-ase`` alternative, so they override earlier keys.

    Returns:
        A MongoDB-style filter dict, as passed to ``wandb.Api().runs``.

    Raises:
        KeyError: if ``algorithms`` contains an unsupported name.
    """
    def _as_list(values):
        # Copy so the returned filter never aliases a caller's (or the
        # default) list. A bare string or scalar is wrapped, not iterated.
        if isinstance(values, str):
            return [values]
        try:
            return list(values)
        except TypeError:
            return [values]

    if aps_ase_extra is None:
        aps_ase_extra = {}

    env_name = f"dmc:{domain.capitalize()}-v1"

    # HACK: the first set of runs was logged without a group name, so
    # "original" means "any group except 'sigmas'".
    group_filter = {'$ne': 'sigmas'} if group == 'original' else {'$eq': group}

    sigma_filter = {} if ase_sigmas is None else {'config.ase_sigma': {'$in': _as_list(ase_sigmas)}}

    alg2query = {
        'lops-aps': {
            'config.algorithm': {'$eq': 'lops-aps'},
            'config.use_riro_for_learner_pi': {'$in': _as_list(learner_pi)},
            'config.load_expert_step': {'$eq': expert_steps},
        },
        'lops-aps-ase': {
            'config.algorithm': {'$eq': 'lops-aps-ase'},
            'config.use_riro_for_learner_pi': {'$in': _as_list(ase_learner_pi)},
            'config.load_expert_step': {'$eq': expert_steps},
            **sigma_filter,
            **aps_ase_extra,
        },
        'mamba': {
            'config.algorithm': {'$eq': 'mamba'},
            'config.use_riro_for_learner_pi': {'$eq': 'none'},
            'config.load_expert_step': {'$eq': expert_steps},
        },
        'pg-gae': {
            # The pure-RL baseline does not use experts.
            'config.algorithm': {'$eq': 'pg-gae'},
            'config.use_riro_for_learner_pi': {'$eq': 'none'},
            'config.load_expert_step': {'$eq': [0]},
        },
    }
    return {
        "$and": [{
            "$or": [alg2query[alg] for alg in algorithms],
            "$and": [{
                'config.env_name': {'$eq': env_name},
                'group': group_filter,
                # 'created_at': {
                #     "$lt": '2023-02-15T2000',
                # }
                # 'created_at': {
                #     "$lt": '2023-02-17T1000',
                # },
            }]
        }]
    }


if __name__ == '__main__':
    import argparse

    # Command-line options: which wandb project to read and how to render.
    cli = argparse.ArgumentParser()
    cli.add_argument("--user", default='anonymous', help='wandb user name')
    cli.add_argument('--proj', default='rpi-manual_sweep', help='wandb project name')
    cli.add_argument("--format", default='pdf', choices=['pdf', 'png'], help="pdf or png")
    cli.add_argument("--dry-run", action='store_true')
    cli.add_argument("--force", action='store_true', help='If true, overwrite the existing plot')
    cli.add_argument("--use-stderr", action='store_true', help='If true, use standard error')
    cli.add_argument("--notitle", action='store_true')
    cli.add_argument("--nolegend", action='store_true')
    args = cli.parse_args()

    api = wandb.Api()

    domains = ['cheetah-run', 'walker-walk', 'pendulum-swingup', 'cartpole-swingup']
    # domains = ['cheetah-run']

    def main_plot_config(domain, expert_steps):
        """Return the config for one main plot: ours vs. baselines.

        Axes: training step vs. best return. One plot is made per domain and
        per expert set. Lines: lops-aps, mamba and pg-gae.
        """
        query = get_query_set(domain, expert_steps,
                              algorithms=['lops-aps', 'mamba', 'pg-gae'],
                              learner_pi=['all'])
        legend = {
            "mamba-none": "Mamba",
            "lops-aps-all": "LOPS-APS",
            "lops-aps-rollin": "LOPS-APS-ri",
            # "lops-aps-ase-all": "LOPS-APS-ASE",
            "pg-gae-none": "PPO-GAE"
        }
        palette = {
            "mamba-none": colors[1],
            "lops-aps-all": colors[0],
            # "lops-aps-rollin": colors[-1],
            # "lops-aps-ase-all": colors[2],
            "pg-gae-none": colors[2]
        }
        return {
            "query": query,
            "xlabel": "Training step",
            "ylabel": "Best return",
            "group_keys": ["algorithm", "use_riro_for_learner_pi"],
            "ykey": "eval/best-so-far",
            "xkey": "step",
            "hbar": "expert_vals",
            "group2legend": legend,
            "group2color": palette,
            "show_title": not args.notitle,
            "show_legend": not args.nolegend,
            "plot_dir": "generated/main-plot",
        }

    plot2config = {
        # Desc: main plot that shows performance between ours vs baselines
        # Axes: Training step vs Best-return
        # Domains: each
        # Lines: lops-aps, mamba and pg-gae
        # Experts: each
        **{f"main-plot-{domain}-{i}": main_plot_config(domain, expert_steps)
           for domain in domains
           for i, expert_steps in enumerate(domain2expert_info[domain])},
        # Desc: LOPS vs LOPS-APS performance plot
        # Axes: Training step vs Best-return
        # Domains: each
        # Lines: lops-aps-ase, lops-aps, mamba and pg-gae
        # Experts: each
        # **{f"plot-{domain}-{i}-2": {
        #     "query": get_query_set(domain, expert_steps, algorithms=['lops-aps', 'lops-aps-ase'],
        #                            learner_pi=['rollin'], ase_learner_pi=['rollin'],
        #                            aps_ase_extra={
        #                                'config.ase_uncertainty': {'$eq': 'value_std'},
        #                                'config.ase_sigma_ratio': {'$eq': 0.5}
        #                            }),
        #     "xlabel": "Training step",
        #     "ylabel": "Best return",
        #     "group_keys": ["algorithm", "use_riro_for_learner_pi"],
        #     "ykey": "eval/best-so-far",
        #     "xkey": "step",
        #     "hbar": "expert_vals",
        #     "group2legend": {
        #         "lops-aps-rollin": "LOPS-APS",
        #         "lops-aps-ase-rollin": "LOPS-ASE",
        #     },
        #     "group2color": {
        #         "lops-aps-rollin": colors[0],
        #         "lops-aps-ase-rollin": colors[1],
        #     },
        #     "plot_dir": "generated/plot-20230317",
        #     "show_title": False
        # } for domain in domains
        #    for i, expert_steps in enumerate(domain2expertsteps[domain])
        #    },
        # Desc: predicted stddev of selected expert value
        # Axes: Training step vs Standard Deviation
        # Domains: each
        # Lines: lops-aps-ase, lops-aps and mamba
        # Experts: each
        # **{f"stddev-plot-{domain}-{i}": {
        #     "query": get_query_set(domain, expert_steps, algorithms=['lops-aps', 'lops-aps-ase', 'mamba']),
        #     "xlabel": "Training step",
        #     "ylabel": "Standard Deviation",
        #     "group_keys": ["algorithm"],
        #     "ykey": "riro/selected_expert_val_std",
        #     "xkey": "step",
        #     "group2legend": {
        #         "mamba": "Mamba",
        #         "lops-aps": "LOPS-APS",
        #         # "lops-aps-rollin": "LOPS-APS-ri",
        #         "lops-aps-ase": "LOPS-APS-ASE",
        #         # "pg-gae-none": "PPO-GAE"
        #     },
        #     "group2color": {
        #         "mamba": colors[0],
        #         "lops-aps": colors[1],
        #         "lops-aps-ase": colors[2]
        #     },
        #     "plot_dir": "generated/stddev-plot",
        #     "extra_txt": f"sigma-{domain2ase_sigma[domain]}"
        # } for domain in domains
        #    for i, expert_steps in enumerate(domain2expertsteps[domain])
        #    },
        # **{f"stddev-plot-{domain}-{i}": {
        #     "query": get_query_set(domain, expert_steps, algorithms=['lops-aps', 'lops-aps-ase', 'mamba']),
        #     "xlabel": "Training step",
        #     "ylabel": "Standard Deviation",
        #     "group_keys": ["algorithm"],
        #     "ykey": "riro/selected_expert_val_std",
        #     "xkey": "step",
        #     "group2legend": {
        #         "mamba": "Mamba",
        #         "lops-aps": "LOPS-APS",
        #         # "lops-aps-rollin": "LOPS-APS-ri",
        #         "lops-aps-ase": "LOPS-ASE",
        #         # "pg-gae-none": "PPO-GAE"
        #     },
        #     "group2color": {
        #         "mamba": colors[0],
        #         "lops-aps": colors[1],
        #         "lops-aps-ase": colors[2]
        #     },
        #     "extra_txt": f"sigma-{domain2ase_sigma[domain]}",
        #     # "plot_dir": "generated/stddev-plot",
        #     "plot_dir": "generated/plot-20230317/stddev-plot",
        #     "show_title": False
        # } for domain in domains
        #    for i, expert_steps in enumerate(domain2expertsteps[domain])
        #    },

        # Desc: comparison between lops-aps vs lops-aps-ase on various sigma, across different experts
        # Axes: Training step vs Best-return
        # Domains: Cheetah-run only!
        # Lines: lops-aps-ase (for each sigma) and lops-aps
        # Experts: each
        # **{f"aps-vs-ase-cheetah-run-{i}-sigm{ase_sigma}": {
        #     "query": get_aps_vs_ase_query_set('cheetah-run',
        #                                       expert_steps,
        #                                       algorithms=['lops-aps', 'lops-aps-ase'],
        #                                       # algorithms=['lops-aps-ase'],
        #                                       learner_pi=['all'],
        #                                       group='sigmas',
        #                                       ase_sigmas=[maybe_toint(ase_sigma)],
        #                                       alg2group={'lops-aps': None,
        #                                                  'lops-aps-ase': 'sigmas'}
        #                                       ),
        #     "xlabel": "Training step",
        #     "ylabel": "Best return",
        #     "group_keys": ["algorithm", "ase_sigma"],
        #     "ykey": "eval/best-so-far",
        #     "xkey": "step",
        #     # "hbar": "expert_vals",
        #     "group2legend": {
        #         "lops-aps-0": "LOPS-APS",
        #         **{f'lops-aps-ase-{maybe_toint(0.5 * (i+1))}': "LOPS-APS-ASE" for i in range(20)}
        #     },
        #     "group2color": {
        #         "lops-aps-0": colors[0],
        #         **{f'lops-aps-ase-{maybe_toint(0.5 * (i+1))}': colors[1] for i in range(20)}
        #     },
        #     "plot_dir": "generated/aps-vs-ase",
        #     "extra_txt": f"sigma={ase_sigma}"
        # } for ase_sigma in ase_sigmas_on_cheetah
        #    for i, expert_steps in enumerate(domain2expertsteps['cheetah-run'])
        #    },

        # Desc: Bar chart
        # Axes: Training step vs Ratio of the expert selected
        # Domains: each
        # Lines: each expert
        # **{f"bar-chart-{domain}-{i}": {
        #     ...
        # }}
    }

    # Fetch the matching runs for every configured plot and render it.
    run_path = f'{args.user}/{args.proj}'
    for plot_name, cfg in plot2config.items():
        print('plot_name', plot_name)
        runs = api.runs(run_path, cfg['query'])
        main(runs, plot_name, cfg, ext=f'.{args.format}', force=args.force,
             dry_run=args.dry_run, use_stderr=args.use_stderr)
